(** Terminal symbols (token kinds) of the grammar described further down in
    this file. Each constructor stands for one kind of token; several
    concrete operators may share a kind (e.g. [NUMERICAL_OP_PLUSMINUS]). *)
type pl0_terminal = 
    IDENTIFIER
  | NUMBER 
  | CONST 
  | VAR 
  | PROCEDURE 
  | CALL 
  | BEGIN 
  | END 
  | IF 
  | THEN 
  | WHILE 
  | DO 
  | ODD 
  | READ 
  | WRITE 
  | CONST_ASSIGN_OP 
  | IDENTIFIER_ASSIGN_OP 
  | NUMERICAL_OP_MULDIV 
  | NUMERICAL_OP_PLUSMINUS 
  | RELATIONAL_OP 
  | DOT 
  | COMMA 
  | SEMICOLON 
  | LEFT_PARENTHESIS 
  | RIGHT_PARENTHESIS

(** Nonterminal symbols of the grammar. [EXTRA_START_SYMBOL] is the left-hand
    side of the added start production; the table builder turns a completed
    item with this left-hand side into the accept action. *)
type pl0_nonterminal = 
    PROG 
  | BLOCK 
  | OPT_CONST_DECLARE 
  | CONST_DECLARE 
  | CONST_ASSIGN 
  | OPT_VAR_DECLARE 
  | IDENTIFIER_LIST 
  | OPT_PROCEDURES 
  | OPT_STATEMENT 
  | STATEMENT 
  | STATEMENT_LIST 
  | EXPRESSION_LIST 
  | CONDITION 
  | EXPRESSION 
  | ADD_SUBTRACT_EXPRESSION 
  | TERM 
  | FACTOR 
  | EXTRA_START_SYMBOL

(** A grammar symbol: a terminal, a nonterminal, the empty string marker
    ([Empty], used as the sole element of an empty right-hand side and inside
    FIRST sets), or the end-of-input marker ([EOF], used as a lookahead). *)
type symbol = 
    T of pl0_terminal 
  | NT of pl0_nonterminal 
  | Empty 
  | EOF

(** A sequence of grammar symbols, e.g. one right-hand side of a production. *)
type symseq = symbol list

(** A production with a single right-hand side. *)
type production_srhs = {lhs: pl0_nonterminal; rhs: symseq}

(** A production grouping all alternative right-hand sides of one
    nonterminal. *)
type production_mrhs = {lhs: pl0_nonterminal; rhs_m: symseq list}

(** Association list from a symbol to the list of symbols in its FIRST set. *)
type first_set_table = (symbol * symbol list) list

(** An LR(1) item: a production, the index [pos] in [prod.rhs] of the symbol
    expected next (the dot position), and the list of lookahead symbols. *)
type lr1_item = {prod: production_srhs; pos: int; lookaheads: symbol list}

(** A node of the canonical collection. [number] identifies the node, [flag]
    is set to [true] once its outgoing edges have been computed, [itemlist]
    is its item set and [outedges] maps a symbol to the number of the node
    reached on that symbol. *)
type itemlist_node = {number: int; flag: bool; itemlist: lr1_item list; outedges: (symbol * int) list}

(** An entry of the action table. [Shift n] carries the number of the
    destination node; [Reduce i] carries the index of a production in the
    flattened single-right-hand-side production list. *)
type lr1_parsing_action =
    Shift of int
  | Reduce of int
  | ParseError
  | SentenceAccept

(** An entry of the goto table; [Goto n] carries the destination node
    number. *)
type lr1_goto_table_entry =
    Goto of int
  | GotoError

(** Action table: association list from a node number to its row, itself an
    association list from a symbol to the action on that symbol. *)
type lr1_action_table = ((int * ((symbol * lr1_parsing_action) list)) list)

(** Goto table: association list from a node number to its row, itself an
    association list from a symbol to the goto entry on that symbol. *)
type lr1_goto_table = ((int * ((symbol * lr1_goto_table_entry) list)) list)

(* 1. Calculate first sets table: *)

(** [lst1 @| lst2] appends to [lst1], in order, every element of [lst2] that
    is not already present in the list built so far. [lst1] itself is kept
    untouched, including any duplicates it may contain. *)
let ( @| ) lst1 lst2 =
  let append_if_absent acc elt =
    if List.mem elt acc then acc else acc @ [elt]
  in
  List.fold_left append_if_absent lst1 lst2

(** Total number of symbols stored across all FIRST sets of [table]; used as
    a growth measure by the fixpoint iteration. *)
let first_set_table_scale (table: first_set_table) =
  table
  |> List.map (fun (_, firsts) -> List.length firsts)
  |> List.fold_left ( + ) 0

(** FIRST set of the symbol sequence [symseq], computed from the per-symbol
    sets in [ref_table]. Symbols are consumed left to right for as long as
    the set accumulated so far still contains [Empty]; the result for an
    empty sequence is [[Empty]].
    @raise Not_found if a symbol that is reached has no entry in
    [ref_table]. *)
let first_not_nterm_syms_of_symseq symseq ref_table =
  let extend (firsts_so_far: symbol list) (next_sym: symbol) : symbol list =
    if List.mem Empty firsts_so_far then
      (* The prefix can derive the empty string: drop the marker and let the
         next symbol contribute (it re-adds [Empty] if it is nullable too). *)
      let without_empty = List.filter (fun s -> s <> Empty) firsts_so_far in
      without_empty @| List.assoc next_sym ref_table
    else
      firsts_so_far
  in
  List.fold_left extend [Empty] symseq

(** Duplicate-free union, in order, of the FIRST sets of every sequence in
    [symseq_lst]. *)
let first_not_nterm_syms_of_multiple_symseq symseq_lst ref_table =
  symseq_lst
  |> List.map (fun symseq -> first_not_nterm_syms_of_symseq symseq ref_table)
  |> List.fold_left ( @| ) []

(** One pass over [productions]: for each production, the FIRST set of its
    left-hand side is extended with the FIRST sets of all its alternatives,
    evaluated against the table as updated by the productions processed
    before it. *)
let updated_nonterm_first_sets_once (old_table: first_set_table) (productions: production_mrhs list) =
  let apply_production table {lhs; rhs_m} =
    let target = NT lhs in
    let widen ((sym, firsts) as entry) =
      match sym = target with
      | true -> (sym, firsts @| first_not_nterm_syms_of_multiple_symseq rhs_m table)
      | false -> entry
    in
    List.map widen table
  in
  List.fold_left apply_production old_table productions

(** Repeats {!updated_nonterm_first_sets_once} until the table stops growing
    (as measured by {!first_set_table_scale}) and returns the last table.
    A progress line is printed on standard output after every pass. *)
let rec update_nonterm_first_sets_rec (productions: production_mrhs list) (table: first_set_table) old_scale =
  let grown_table = updated_nonterm_first_sets_once table productions in
  let grown_scale = first_set_table_scale grown_table in
  Printf.printf "First set: Old Scale %d -> New Scale %d\n" old_scale grown_scale;
  if grown_scale <= old_scale then
    grown_table
  else
    update_nonterm_first_sets_rec productions grown_table grown_scale

(** Seed table for the FIRST-set computation: every nonterminal starts with
    an empty set, every other symbol with the singleton set containing
    itself. *)
let build_initial_first_set_table (symbols: symbol list) =
  List.map
    (function
      | NT _ as nonterm -> (nonterm, [])
      | other -> (other, [other]))
    symbols

(** Builds the FIRST-set table of [symbols] for the grammar [productions] by
    iterating from the seed table until no set grows any more. *)
let build_first_set_table (symbols: symbol list) (productions: production_mrhs list) =
  let seed = build_initial_first_set_table symbols in
  update_nonterm_first_sets_rec productions seed (first_set_table_scale seed)

(* 2. Calculate closure: *)

(** [itemlst1 @|!| itemlst2] merges the items of [itemlst2] into [itemlst1].
    An incoming item whose production and dot position already occur in the
    accumulated list has its lookaheads merged into the matching item(s)
    (incoming lookaheads first, then the existing ones not yet present);
    any other incoming item is appended at the end. *)
let ( @|!| ) itemlst1 itemlst2 =
  let same_core a b = a.prod = b.prod && a.pos = b.pos in
  let absorb acc incoming =
    if List.exists (same_core incoming) acc then
      List.map
        (fun existing ->
          if same_core incoming existing then
            {existing with lookaheads = incoming.lookaheads @| existing.lookaheads}
          else
            existing)
        acc
    else
      acc @ [incoming]
  in
  List.fold_left absorb itemlst1 itemlst2

(** Total number of lookahead symbols over all items of [items]; used as a
    growth measure by the closure iteration. *)
let itemlist_scale items =
  items
  |> List.map (fun {lookaheads; _} -> List.length lookaheads)
  |> List.fold_left ( + ) 0

(** One expansion pass of the LR(1) closure. For every item of [items] whose
    dot stands before a nonterminal [X], one item per alternative of [X] is
    merged into the result (see [( @|!| )]), with the dot at position 0, or
    at position 1 for the [[Empty]] alternative so that it counts as
    complete.

    The lookaheads of the added items are the lookaheads of the source item
    when nothing follows [X]; otherwise they are the union, over every
    lookahead [a] of the source item, of FIRST(beta a), where [beta] is the
    sequence following [X].

    @raise Not_found if [X] has no entry in [productions] or if a symbol
    needed for a FIRST computation is missing from [ref_first_set_table]. *)
let update_closure_once (items: lr1_item list) (productions: production_mrhs list) (ref_first_set_table: first_set_table) : lr1_item list =
  let fold_f itemlist {prod={lhs=_; rhs=symseq}; pos; lookaheads} =
    match (List.nth_opt symseq pos) with
    | Some (NT nterm_x) ->
      (* beta: the symbols that follow X inside the production *)
      let symseq_follow_x_within_prod = List.filteri (fun i _ -> i > pos) symseq in
      let new_lookaheads =
        if symseq_follow_x_within_prod = [] then
          lookaheads
        else
          (* Each lookahead must be tried on its own: appending all of them
             to beta as one sequence would let only the first one through
             when beta can derive the empty string. *)
          List.fold_left
            (fun acc lookahead ->
              acc @| first_not_nterm_syms_of_symseq
                       (symseq_follow_x_within_prod @ [lookahead])
                       ref_first_set_table)
            []
            lookaheads
      in
      let {lhs=_; rhs_m=x_prod_rhs_m} = List.find (fun {lhs;rhs_m=_} -> lhs = nterm_x) productions in
      itemlist @|!| (List.map
        (fun symseq ->
          {
            prod = {lhs = nterm_x; rhs = symseq};
            (* an empty right-hand side is already fully recognised *)
            pos = if symseq = [Empty] then 1 else 0;
            lookaheads = new_lookaheads
          }
        )
        x_prod_rhs_m)
    | Some (T _ | Empty | EOF) | None -> itemlist
  in
    List.fold_left fold_f items items

(** Repeats {!update_closure_once} until the item list stops growing (as
    measured by {!itemlist_scale}) and returns the last list. A progress
    line is printed on standard output after every pass. *)
let rec update_closure_rec items productions ref_first_set_table old_scale : lr1_item list =
  let grown = update_closure_once items productions ref_first_set_table in
  let grown_scale = itemlist_scale grown in
  Printf.printf "Closure: Old Scale %d -> New Scale %d\n" old_scale grown_scale;
  if grown_scale <= old_scale then
    grown
  else
    update_closure_rec grown productions ref_first_set_table grown_scale

(** LR(1) closure of [items] for the grammar [productions]. The item list is
    the last argument so that the function can be used at the end of a
    pipeline. *)
let closure productions ref_first_set_table items =
  update_closure_rec items productions ref_first_set_table (itemlist_scale items)

(* 3. Goto function: *)

(** Kernel of the goto set: every item of [items] whose dot stands before
    [symbol], with the dot moved one position to the right. The original
    order of the items is preserved; no closure is applied. *)
let goto_without_closure items symbol =
  List.filter_map
    (fun ({prod = {lhs = _; rhs}; pos; lookaheads = _} as item) ->
      match List.nth_opt rhs pos with
      | Some next ->
        if next = symbol then Some {item with pos = pos + 1} else None
      | None -> None)
    items

(* 4. Build Canonical Collection: *)

(** Adds a node for [itemlst] to [collection] unless a node with a
    structurally equal item list is already there. Returns the (possibly
    extended) collection together with the number of the node that holds
    [itemlst]. A new node is numbered with the current length of the
    collection, is unflagged, has no out edges and is put at the front. *)
let add_distinct_to_collection (itemlst: lr1_item list) (collection: itemlist_node list) =
  match List.find_opt (fun node -> node.itemlist = itemlst) collection with
  | Some {number; _} -> (collection, number)
  | None ->
    let new_no = List.length collection in
    let fresh_node = {number = new_no; flag = false; itemlist = itemlst; outedges = []} in
    (fresh_node :: collection, new_no)

(** One expansion pass over the canonical collection. Every node of
    [collection] that is not flagged yet gets its goto targets computed: for
    each distinct terminal or nonterminal standing after a dot in the node,
    the closed goto set is added to the collection (if not present already)
    and an out edge to it is recorded. The node is then flagged. Nodes added
    during the pass are left unflagged for a later pass. *)
let update_canonical_collection_once
(collection: itemlist_node list)
(productions: production_mrhs list)
(ref_first_set_table: first_set_table) : itemlist_node list =
  let expand_node cur_coll {number = x_no; flag = x_marked; itemlist = x_itemlist; outedges = _} =
    if x_marked then
      cur_coll
    else
      (* Follow the symbol after the dot of one item, unless an edge on that
         symbol has been recorded for this node already. *)
      let follow_item (coll_acc, edges) {prod = {lhs = _; rhs = item_rhs}; pos; lookaheads = _} =
        match List.nth_opt item_rhs pos with
        | Some ((T _ | NT _) as sym) ->
          if List.mem_assoc sym edges then
            (coll_acc, edges)
          else
            let target_items =
              closure productions ref_first_set_table (goto_without_closure x_itemlist sym)
            in
            let (coll_next, target_no) = add_distinct_to_collection target_items coll_acc in
            (coll_next, (sym, target_no) :: edges)
        | Some (Empty | EOF) | None -> (coll_acc, edges)
      in
      let (coll_with_targets, x_edges) =
        List.fold_left follow_item (cur_coll, []) x_itemlist
      in
      List.map
        (fun node ->
          if node.number = x_no then {node with flag = true; outedges = x_edges} else node)
        coll_with_targets
  in
  (* Iterate over the snapshot passed in while threading the growing
     collection through as the accumulator. *)
  List.fold_left expand_node collection collection

(** Repeats {!update_canonical_collection_once} until a pass adds no node
    and returns the last collection. A progress line is printed on standard
    output after every pass. *)
let rec update_canonical_collection_rec collection productions ref_first_set_table old_scale =
  let grown = update_canonical_collection_once collection productions ref_first_set_table in
  let grown_scale = List.length grown in
  Printf.printf "Canonical Coll: Old Scale %d -> New Scale %d\n" old_scale grown_scale;
  if grown_scale <= old_scale then
    grown
  else
    update_canonical_collection_rec grown productions ref_first_set_table grown_scale

(** Canonical collection of LR(1) item sets reachable from [initial_item].
    Node 0 holds the closure of [initial_item]. *)
let canonical_collection initial_item productions ref_first_set_table =
  let start_node =
    {
      number = 0;
      flag = false;
      itemlist = closure productions ref_first_set_table [initial_item];
      outedges = [];
    }
  in
  update_canonical_collection_rec [start_node] productions ref_first_set_table 1

(* 5. Build LR1 Parsing tables *)

(** [assoc_list_replace_default assoclst default key rep] returns [assoclst]
    with the value of every binding whose key equals [key] replaced by
    [rep]. Bindings with another key are returned unchanged.
    @raise Failure ["Already Filled"] if a binding for [key] holds a value
    other than [default]. *)
let assoc_list_replace_default assoclst default key rep =
  let fill ((k, v) as binding) =
    if k <> key then binding
    else if v = default then (k, rep)
    else raise (Failure "Already Filled")
  in
  List.map fill assoclst

(** [assoc_list_replace_default_multikeys assoclst default keys rep] returns
    [assoclst] with the value of every binding whose key occurs in [keys]
    replaced by [rep]. Bindings with another key are returned unchanged.
    @raise Failure ["Already Filled"] if a binding for one of [keys] holds a
    value other than [default]. *)
let assoc_list_replace_default_multikeys assoclst default keys rep =
  let fill ((k, v) as binding) =
    if not (List.mem k keys) then binding
    else if v = default then (k, rep)
    else raise (Failure "Already Filled")
  in
  List.map fill assoclst

(** Builds the LR(1) tables from a canonical collection. Returns the
    flattened list of single-right-hand-side productions (whose indices are
    the numbers used by [Reduce]), the action table and the goto table.
    Each table has one row per node of the collection; a row has one column
    per terminal plus [EOF] (action) or per nonterminal (goto).
    @raise Failure ["Already Filled"] when two actions compete for the same
    table cell.
    @raise Not_found if a completed item refers to a production that is not
    in [productions_mrhs]. *)
let build_lr1_parsing_tables
(symbols: symbol list)
(productions_mrhs: production_mrhs list)
(ref_canonical_collection: itemlist_node list)
: ((production_srhs list) * lr1_action_table * lr1_goto_table) =
  (* Rows in which every cell still holds the error entry. *)
  let add_blank_column (action_row, goto_row) s =
    match s with
    | NT _ -> (action_row, (s, GotoError) :: goto_row)
    | T _ | EOF -> ((s, ParseError) :: action_row, goto_row)
    | Empty -> (action_row, goto_row)
  in
  let (blank_action_row, blank_goto_row) =
    List.fold_left add_blank_column ([], []) symbols
  in
  (* One production per alternative, in declaration order. *)
  let productions_srhs =
    List.concat_map
      (fun {lhs; rhs_m} -> List.map (fun symseq -> {lhs = lhs; rhs = symseq}) rhs_m)
      productions_mrhs
  in
  let production_index = List.mapi (fun i prod -> (prod, i)) productions_srhs in
  (* A completed item yields accept (start production) or a reduce on each
     of its lookaheads; other items leave the row alone. *)
  let fill_reductions row {prod = {lhs; rhs}; pos; lookaheads} =
    match List.nth_opt rhs pos with
    | Some _ -> row
    | None ->
      if lhs = EXTRA_START_SYMBOL then
        assoc_list_replace_default row ParseError EOF SentenceAccept
      else
        let rule_no = List.assoc {lhs = lhs; rhs = rhs} production_index in
        assoc_list_replace_default_multikeys row ParseError lookaheads (Reduce rule_no)
  in
  (* An out edge yields a goto (nonterminal) or a shift (terminal). *)
  let fill_edges (action_row, goto_row) (sym, dest) =
    match sym with
    | NT _ -> (action_row, assoc_list_replace_default goto_row GotoError sym (Goto dest))
    | T _ -> (assoc_list_replace_default action_row ParseError sym (Shift dest), goto_row)
    | Empty | EOF -> (action_row, goto_row)
  in
  let add_state_rows (action_tbl, goto_tbl)
      {number = x_no; flag = _; itemlist = x_itemlst; outedges = x_outedges} =
    let reduce_row = List.fold_left fill_reductions blank_action_row x_itemlst in
    let (action_row, goto_row) =
      List.fold_left fill_edges (reduce_row, blank_goto_row) x_outedges
    in
    ((x_no, action_row) :: action_tbl, (x_no, goto_row) :: goto_tbl)
  in
  let (action_table, goto_table) =
    List.fold_left add_state_rows ([], []) ref_canonical_collection
  in
  (productions_srhs, action_table, goto_table)

(* 6. Go! *)

(** Every grammar symbol: all terminals, then all nonterminals, then [Empty]
    and [EOF]. The position of a symbol in this list is significant: it is
    the basis of the integer codes in [all_sym_to_int_map]. *)
let all_symbols =
  let terminals = [
    IDENTIFIER; NUMBER; CONST; VAR; PROCEDURE; CALL; BEGIN; END; IF; THEN;
    WHILE; DO; ODD; READ; WRITE; CONST_ASSIGN_OP; IDENTIFIER_ASSIGN_OP;
    NUMERICAL_OP_MULDIV; NUMERICAL_OP_PLUSMINUS; RELATIONAL_OP; DOT; COMMA;
    SEMICOLON; LEFT_PARENTHESIS; RIGHT_PARENTHESIS;
  ] in
  let nonterminals = [
    PROG; BLOCK; OPT_CONST_DECLARE; CONST_DECLARE; CONST_ASSIGN;
    OPT_VAR_DECLARE; IDENTIFIER_LIST; OPT_PROCEDURES; OPT_STATEMENT;
    STATEMENT; STATEMENT_LIST; EXPRESSION_LIST; CONDITION; EXPRESSION;
    ADD_SUBTRACT_EXPRESSION; TERM; FACTOR; EXTRA_START_SYMBOL;
  ] in
  List.map (fun t -> T t) terminals
  @ List.map (fun nt -> NT nt) nonterminals
  @ [Empty; EOF]

(** Association list from each symbol of [all_symbols] to the integer code
    used in the generated files: the position of the symbol in the list,
    except for [NT EXTRA_START_SYMBOL] (-1), [Empty] (-2) and [EOF] (-3). *)
let all_sym_to_int_map =
  let code_of position = function
    | NT EXTRA_START_SYMBOL -> -1
    | Empty -> -2
    | EOF -> -3
    | T _ | NT _ -> position
  in
  List.mapi (fun position sym -> (sym, code_of position sym)) all_symbols

(** Writes a production to [oc] as the integer code of its left-hand side,
    a colon, and the number of symbols of its right-hand side (0 for the
    [[Empty]] right-hand side). The right-hand side symbols themselves are
    not written.
    @raise Not_found if the left-hand side has no code in
    [all_sym_to_int_map]. *)
let output_production_srhs_in_ints_to {lhs;rhs} oc =
  let lhs_code = List.assoc (NT lhs) all_sym_to_int_map in
  let rhs_length =
    match rhs with
    | [Empty] -> 0
    | _ -> List.length rhs
  in
  Printf.fprintf oc "%d:%d" lhs_code rhs_length

(** Writes the production list, the action table and the goto table to the
    three given files, in that order.

    Production file: one line per production, holding its index followed by
    the output of [output_production_srhs_in_ints_to].
    Table files: a header line listing the integer codes of the column
    symbols (taken from the first row), then one line per node with the node
    number followed by the quoted cell texts: E for an error cell, A for
    accept, S/R/G followed by a number for shift, reduce and goto.

    Each file is closed even when writing it fails. An empty table produces
    a header line without column codes instead of an exception.

    @raise Sys_error if a file cannot be opened or written.
    @raise Not_found if a symbol has no code in [all_sym_to_int_map]. *)
let write_lr1_parsing_tables_to_files
(production_list: production_srhs list)
(action_table: lr1_action_table)
(goto_table: lr1_goto_table)
(produciton_list_out_filename: string)
(action_table_out_filename: string)
(goto_table_out_filename: string) : unit =
  (* Runs [write] on a freshly opened channel. The explicit [close_out]
     reports flush errors on the normal path; the [finally] handler only
     releases the channel when [write] raised. *)
  let with_out_file filename write =
    let oc = open_out filename in
    Fun.protect
      ~finally:(fun () -> close_out_noerr oc)
      (fun () -> write oc; close_out oc)
  in
  (* Shared layout of the action and goto tables. *)
  let write_table oc table string_of_entry =
    Printf.fprintf oc "S, ";
    (match table with
     | [] -> ()
     | (_, first_row) :: _ ->
       List.iter
         (fun (sym, _) -> Printf.fprintf oc "%d, " (List.assoc sym all_sym_to_int_map))
         first_row);
    Printf.fprintf oc "\n";
    List.iter
      (fun (state_number, row) ->
        Printf.fprintf oc "%d, " state_number;
        List.iter
          (fun (_, entry) -> Printf.fprintf oc "\"%s\", " (string_of_entry entry))
          row;
        Printf.fprintf oc "\n")
      table
  in
  let string_of_action = function
    | ParseError -> "E"
    | SentenceAccept -> "A"
    | Shift s -> "S" ^ string_of_int s
    | Reduce r -> "R" ^ string_of_int r
  in
  let string_of_goto = function
    | GotoError -> "E"
    | Goto g -> "G" ^ string_of_int g
  in
  with_out_file produciton_list_out_filename (fun oc ->
    List.iteri
      (fun i prod ->
        Printf.fprintf oc "%d, " i;
        output_production_srhs_in_ints_to prod oc;
        Printf.fprintf oc "\n")
      production_list);
  with_out_file action_table_out_filename (fun oc ->
    write_table oc action_table string_of_action);
  with_out_file goto_table_out_filename (fun oc ->
    write_table oc goto_table string_of_goto)

(** The grammar. Each record lists every alternative right-hand side of one
    nonterminal; the alternative [[Empty]] stands for the empty string. The
    order of the records and of their alternatives determines the production
    indices produced by [build_lr1_parsing_tables]. *)
let production_mrhs_list = [
  (* Added start production, then program and block structure. *)
  {lhs=EXTRA_START_SYMBOL; rhs_m=[[NT PROG]]};
  {lhs=PROG; rhs_m=[[NT BLOCK; T DOT]]};
  {lhs=BLOCK; rhs_m=[[NT OPT_CONST_DECLARE; NT OPT_VAR_DECLARE; NT OPT_PROCEDURES; NT OPT_STATEMENT]]};

  (* Optional constant and variable declarations. *)
  {lhs=OPT_CONST_DECLARE; rhs_m=[[Empty];[NT CONST_DECLARE; T SEMICOLON]]};
  {lhs=CONST_DECLARE; rhs_m=[[T CONST; NT CONST_ASSIGN];[NT CONST_DECLARE; T COMMA; NT CONST_ASSIGN]]};
  {lhs=CONST_ASSIGN; rhs_m=[[T IDENTIFIER; T CONST_ASSIGN_OP; T NUMBER]]};
  {lhs=OPT_VAR_DECLARE; rhs_m=[[Empty];[T VAR; NT IDENTIFIER_LIST; T SEMICOLON]]};
  {lhs=IDENTIFIER_LIST; rhs_m=[[T IDENTIFIER];[NT IDENTIFIER_LIST; T COMMA; T IDENTIFIER]]};

  (* Zero or more procedure declarations, each with its own nested block. *)
  {lhs=OPT_PROCEDURES; rhs_m=[[Empty];[NT OPT_PROCEDURES; T PROCEDURE; T IDENTIFIER; T SEMICOLON; NT BLOCK; T SEMICOLON]]};

  (* Statements. *)
  {lhs=OPT_STATEMENT; rhs_m=[[Empty];[NT STATEMENT]]};
  {lhs=STATEMENT; rhs_m=[
    [T IDENTIFIER; T IDENTIFIER_ASSIGN_OP; NT EXPRESSION];
    [T CALL; T IDENTIFIER];
    [T BEGIN; NT STATEMENT_LIST; T END];
    [T IF; NT CONDITION; T THEN; NT STATEMENT];
    [T WHILE; NT CONDITION; T DO; NT STATEMENT];
    [T READ; T LEFT_PARENTHESIS; NT IDENTIFIER_LIST; T RIGHT_PARENTHESIS];
    [T WRITE; T LEFT_PARENTHESIS; NT EXPRESSION_LIST; T RIGHT_PARENTHESIS]]};
  {lhs=STATEMENT_LIST; rhs_m=[[NT OPT_STATEMENT];[NT STATEMENT_LIST; T SEMICOLON; NT OPT_STATEMENT]]};
  (* Conditions and expressions; the list-like rules are left-recursive. *)
  {lhs=EXPRESSION_LIST; rhs_m=[[NT EXPRESSION];[NT EXPRESSION_LIST; T COMMA; NT EXPRESSION]]};
  {lhs=CONDITION; rhs_m=[[T ODD; NT EXPRESSION];[NT EXPRESSION; T RELATIONAL_OP; NT EXPRESSION]]};
  {lhs=EXPRESSION; rhs_m=[[NT ADD_SUBTRACT_EXPRESSION];[T NUMERICAL_OP_PLUSMINUS; NT ADD_SUBTRACT_EXPRESSION]]};
  {lhs=ADD_SUBTRACT_EXPRESSION; rhs_m=[[NT TERM];[NT ADD_SUBTRACT_EXPRESSION; T NUMERICAL_OP_PLUSMINUS; NT TERM]]};
  {lhs=TERM; rhs_m=[[NT FACTOR];[NT TERM; T NUMERICAL_OP_MULDIV; NT FACTOR]]};
  {lhs=FACTOR; rhs_m=[[T IDENTIFIER];[T NUMBER];[T LEFT_PARENTHESIS; NT EXPRESSION; T RIGHT_PARENTHESIS]]};
]

(* The bindings below are evaluated when the module is initialised, in this
   order; the builders they call print progress lines on standard output. *)

(* FIRST sets of every symbol of [all_symbols] for the grammar. *)
let pl0_sym_first_set_table = build_first_set_table all_symbols production_mrhs_list
(* Start item: the added start production with the dot at the far left and
   [EOF] as its only lookahead. *)
let ini_item = {prod = {lhs=EXTRA_START_SYMBOL; rhs=[NT PROG]}; pos = 0; lookaheads = [EOF]}

(* Canonical collection of LR(1) item sets grown from [ini_item]. *)
let cc = canonical_collection ini_item production_mrhs_list pl0_sym_first_set_table

(* [p]: flattened production list, [a]: action table, [g]: goto table. *)
let (p, a, g) = build_lr1_parsing_tables all_symbols production_mrhs_list cc

(** Directory that receives the generated files. *)
let generated_folder = "./generated/"

(* Entry point: create [generated_folder] when it is missing, then write the
   production list and both tables into it as p.csv, a.csv and g.csv. *)
let () =
  (* The permission argument has to be an octal literal: decimal [777] is
     [0o1411], which leaves the owner without write and search permission on
     the new directory, so the files could not be created in it. *)
  if not (Sys.file_exists generated_folder) then Sys.mkdir generated_folder 0o777;
  write_lr1_parsing_tables_to_files p a g
    (generated_folder ^ "p.csv") (generated_folder ^ "a.csv") (generated_folder ^ "g.csv")